#include "solver.h"

#include <cassert>
#include <cmath>
#include <cstdio>
#include <limits>

namespace {

// Returns true when |t| is smaller than the machine epsilon of T (for float,
// roughly 1.19e-7). This is an *absolute* tolerance; the solvers below only
// apply it to values derived after normalizing by the leading coefficient.
//
// Kept in an anonymous namespace so this file-local helper has internal
// linkage and cannot clash with an IsZero defined in another translation unit.
template<typename T>
bool IsZero(T t) {
  const T epsilon = std::numeric_limits<T>::epsilon();
  return t > -epsilon && t < epsilon;
}

}  // namespace

namespace raytrace {
namespace solver {

namespace {

// pi / 3, used by the trigonometric branch of SolveCubic. Spelled out because
// M_PI is not part of standard C++.
const float kPiOverThree = 1.04719755119659774615f;

// Appends the root of a*x + b = 0 to `solutions`. When a == 0 there is no
// unique root (either none or every x), so nothing is appended.
void SolveLinear(float a, float b, std::vector<float>& solutions) {
  if (a != 0) {
    solutions.push_back(-b / a);
  }
}

// Subtracts `sub` from solutions[first], ..., solutions.back(), i.e. undoes
// the substitution x = y - sub for the roots appended by the current call
// only; entries the caller already had in `solutions` are not touched.
void Resubstitute(std::vector<float>& solutions, std::size_t first,
                  float sub) {
  for (std::size_t i = first; i < solutions.size(); ++i) {
    solutions[i] -= sub;
  }
}

}  // namespace

// Appends the real roots of coeffs[0]*x^2 + coeffs[1]*x + coeffs[2] = 0 to
// `solutions`; existing entries are left untouched.
//
//  - two distinct real roots: both are appended, larger first;
//  - a (numerically) double root: it is appended once;
//  - no real roots: nothing is appended.
// If coeffs[0] == 0 the equation is linear and its root (if any) is appended.
//
// Expects coeffs.size() >= 3; only the first three entries are read.
void SolveQuadric(const std::vector<float>& coeffs,
                  std::vector<float>& solutions) {
  if (coeffs[0] == 0) {
    // Normalizing by a zero leading coefficient would yield inf/NaN "roots".
    SolveLinear(coeffs[1], coeffs[2], solutions);
    return;
  }

  // Normal form: x^2 + 2px + q = 0, whose roots are -p +/- sqrt(p^2 - q).
  const float p = coeffs[1] / (2 * coeffs[0]);
  const float q = coeffs[2] / coeffs[0];

  const float det = p * p - q;

  if (IsZero(det)) {
    solutions.push_back(-p);
  } else if (det > 0) {
    const float sqrt_det = sqrtf(det);

    // Numerically stable evaluation: -p and the square root are combined with
    // the same sign, so the larger-magnitude root is computed without
    // cancellation; the other root follows from Vieta's relation x1 * x2 = q.
    // (The naive "sqrt_det - p" loses all precision when |p| >> sqrt(|q|).)
    // x1 cannot be zero here because sqrt_det > 0.
    const float x1 = (p >= 0) ? -p - sqrt_det : -p + sqrt_det;
    const float x2 = q / x1;

    // Larger root first.
    solutions.push_back(x1 > x2 ? x1 : x2);
    solutions.push_back(x1 > x2 ? x2 : x1);
  }
}

// Appends the real roots of
//   coeffs[0]*x^3 + coeffs[1]*x^2 + coeffs[2]*x + coeffs[3] = 0
// to `solutions` using Cardano's formula, or its trigonometric form when
// there are three distinct real roots. Existing entries of `solutions` are
// left untouched.
//
// Appends 1, 2 or 3 roots: a triple root is reported once, and for a
// single + double root pair each is reported once. When three distinct real
// roots exist the first one appended is the largest (SolveQuartic relies on
// the first root being real). If coeffs[0] == 0 the equation is solved as a
// quadratic in the remaining coefficients.
//
// Expects coeffs.size() >= 4; only the first four entries are read.
void SolveCubic(const std::vector<float>& coeffs,
                std::vector<float>& solutions) {
  if (coeffs[0] == 0) {
    const std::vector<float> quadric(coeffs.begin() + 1, coeffs.begin() + 4);
    SolveQuadric(quadric, solutions);
    return;
  }

  // Only the roots appended from here on get resubstituted at the end.
  const std::size_t first = solutions.size();

  // Normal form: x^3 + Ax^2 + Bx + C = 0
  const float A = coeffs[1] / coeffs[0];
  const float B = coeffs[2] / coeffs[0];
  const float C = coeffs[3] / coeffs[0];

  // Substitute x = y - A/3 to eliminate the quadratic term, giving
  // y^3 + 3py + 2q = 0. p and q are pre-scaled by 1/3 and 1/2 so that
  // Cardano's discriminant is simply q^2 + p^3.
  const float sq_A = A * A;
  const float p = 1.0f / 3 * (-1.0f / 3 * sq_A + B);
  const float q = 1.0f / 2 * (2.0f / 27 * A * sq_A - 1.0f / 3 * A * B + C);

  // Cardano's discriminant: > 0 one real root, < 0 three distinct real
  // roots, ~0 a repeated root.
  const float cb_p = p * p * p;
  const float det = q * q + cb_p;

  if (IsZero(det)) {
    if (IsZero(q)) {
      // One triple root.
      solutions.push_back(0);
    } else {
      // One single root (2u) and one double root (-u).
      const float u = cbrt(-q);
      solutions.push_back(2 * u);
      solutions.push_back(-u);
    }
  } else if (det < 0) {
    // Three distinct real roots (casus irreducibilis), found with the
    // trigonometric form. det < 0 implies p < 0, so both square roots are
    // real.
    float cos_arg = -q / sqrtf(-cb_p);
    // Rounding can push the ratio marginally outside [-1, 1], where acosf
    // returns NaN.
    if (cos_arg > 1) {
      cos_arg = 1;
    } else if (cos_arg < -1) {
      cos_arg = -1;
    }
    const float phi = 1.0f / 3 * acosf(cos_arg);
    const float t = 2 * sqrtf(-p);

    // y_k = t * cos(phi - 2*pi*k/3), k = 0, 1, 2, rewritten using
    // cos(a - 2pi/3) = -cos(a + pi/3) and cos(a - 4pi/3) = -cos(a - pi/3).
    // The leading minus signs on the last two roots are essential.
    solutions.push_back(t * cosf(phi));
    solutions.push_back(-t * cosf(phi + kPiOverThree));
    solutions.push_back(-t * cosf(phi - kPiOverThree));
  } else {
    // One real root; the complex-conjugate pair is discarded.
    const float sqrt_det = sqrtf(det);
    const float u = cbrt(sqrt_det - q);
    const float v = -cbrt(sqrt_det + q);

    solutions.push_back(u + v);
  }

  // Resubstitute x = y - A/3.
  Resubstitute(solutions, first, 1.0f / 3 * A);
}

// Appends the real roots of
//   coeffs[0]*x^4 + coeffs[1]*x^3 + coeffs[2]*x^2 + coeffs[3]*x + coeffs[4] = 0
// to `solutions` (Ferrari's method on the depressed quartic). Existing
// entries of `solutions` are left untouched. Up to four roots are appended;
// if the factorization below cannot be formed with real coefficients
// (z^2 - r or 2z - p clearly negative), nothing is appended. If
// coeffs[0] == 0 the equation is solved as a cubic in the remaining
// coefficients.
//
// Expects coeffs.size() >= 5; only the first five entries are read.
void SolveQuartic(const std::vector<float>& coeffs,
                  std::vector<float>& solutions) {
  if (coeffs[0] == 0) {
    const std::vector<float> cubic(coeffs.begin() + 1, coeffs.begin() + 5);
    SolveCubic(cubic, solutions);
    return;
  }

  // Only the roots appended from here on get resubstituted at the end.
  const std::size_t first = solutions.size();

  // Normal form: x^4 + Ax^3 + Bx^2 + Cx + D = 0
  const float A = coeffs[1] / coeffs[0];
  const float B = coeffs[2] / coeffs[0];
  const float C = coeffs[3] / coeffs[0];
  const float D = coeffs[4] / coeffs[0];

  // Substitute x = y - A/4 to eliminate the cubic term:
  // y^4 + py^2 + qy + r = 0.
  const float sq_A = A * A;

  const float p = -3.0f / 8 * sq_A + B;
  const float q = 1.0f / 8 * sq_A * A - 1.0f / 2 * A * B + C;
  const float r = -3.0f / 256 * sq_A * sq_A + 1.0f / 16 * sq_A * B -
                  1.0f / 4 * A * C + D;

  if (IsZero(r)) {
    // No absolute term: y(y^3 + py + q) = 0.
    std::vector<float> cubic_c;
    cubic_c.push_back(1);
    cubic_c.push_back(0);
    cubic_c.push_back(p);
    cubic_c.push_back(q);

    SolveCubic(cubic_c, solutions);

    solutions.push_back(0);
  } else {
    // Solve the resolvent cubic
    //   z^3 - (p/2) z^2 - r z + (r p / 2 - q^2 / 8) = 0.
    std::vector<float> cubic_c;
    cubic_c.push_back(1);
    cubic_c.push_back(-p / 2.0f);
    cubic_c.push_back(-r);
    cubic_c.push_back((r * p / 2.0f) - (q * q / 8.0f));

    std::vector<float> cubic_s;
    SolveCubic(cubic_c, cubic_s);
    if (cubic_s.empty()) {
      return;
    }

    // Take its first root (real; the largest when there are three) ...
    const float z = cubic_s[0];

    // ... so that, with u = sqrt(z^2 - r), v = sqrt(2z - p) and s = sign(q),
    // the depressed quartic factors as
    //   (y^2 + s*v*y + z - u) * (y^2 - s*v*y + z + u).
    float u = z * z - r;
    float v = 2 * z - p;

    if (IsZero(u)) {
      u = 0;
    } else if (u > 0) {
      u = sqrtf(u);
    } else {
      return;
    }

    if (IsZero(v)) {
      v = 0;
    } else if (v > 0) {
      v = sqrtf(v);
    } else {
      return;
    }

    std::vector<float> quadric_c;
    quadric_c.push_back(1);
    quadric_c.push_back(q < 0 ? -v : v);
    quadric_c.push_back(z - u);

    SolveQuadric(quadric_c, solutions);

    quadric_c.clear();
    quadric_c.push_back(1);
    quadric_c.push_back(q < 0 ? v : -v);
    quadric_c.push_back(z + u);

    SolveQuadric(quadric_c, solutions);
  }

  // Resubstitute x = y - A/4.
  Resubstitute(solutions, first, A / 4.0f);
}

// Clears `solutions` and fills it with the real roots of
//   coeffs[0]*x^n + coeffs[1]*x^(n-1) + ... + coeffs[n] = 0,
// where n = coeffs.size() - 1 and n <= 4 (highest-degree coefficient first).
// Roots are not sorted, and depending on the solver path a repeated root may
// be reported once or more than once. A zero leading coefficient is handled
// by solving the lower-degree equation.
//
// NOTE(review): for coeffs.size() <= 1 (no x term at all) a single 0 is
// reported, which is not a root in the mathematical sense; kept as-is for
// existing callers — confirm whether anything relies on it.
// Degrees above 4 trip an assert in debug builds and produce no solutions
// when NDEBUG is defined.
void Solve(const std::vector<float>& coeffs,
           std::vector<float>& solutions) {
  solutions.clear();

  switch (coeffs.size()) {
    case 0:
    case 1:
      solutions.push_back(0);
      break;
    case 2:
      SolveLinear(coeffs[0], coeffs[1], solutions);
      break;
    case 3:
      SolveQuadric(coeffs, solutions);
      break;
    case 4:
      SolveCubic(coeffs, solutions);
      break;
    case 5:
      SolveQuartic(coeffs, solutions);
      break;
    default:
      assert(false && "Unable to solve higher order than quartics");
      break;
  }
}

} // end namespace
} // end namespace

